# Repository of the parent CUDA "base" image. No default is declared, so the
# value has to be passed in at build time.
ARG IMAGE_NAME
# Common parent stage. The tag is assembled from the CUDA version, the OS
# distro and version, and an optional tag suffix from the manifest.
FROM ${IMAGE_NAME}:{{ cuda.version.major }}.{{ cuda.version.minor }}-base-{{ cuda.os.distro }}{{ cuda.os.version }}{{ cuda.image_tag_suffix if "image_tag_suffix" in cuda }} AS base

{% if "x86_64" in cuda %}
# x86_64 stage. It is named after Docker's architecture label (amd64) so the
# final stage can select it through base-${TARGETARCH}.
FROM base AS base-amd64

# Component versions from the x86_64 section of the manifest; consumed by the
# package install in the final stage.
ENV NV_CUDA_LIB_VERSION {{ cuda.x86_64.components.libraries.version }}
ENV NV_NVTX_VERSION {{ cuda.x86_64.components.nvtx.version }}
ENV NV_LIBNPP_VERSION {{ cuda.x86_64.components.libnpp.version }}

{# CUDA 10.1 and 10.2 use the package name libcublas10. For other versions
   suffix stays unset, which renders as an empty string with Jinja's default
   (non-strict) undefined handling. #}
{% if "libcublas" in cuda.x86_64.components %}
    {% if cuda.version.major_minor in ["10.1", "10.2"] %}
        {% set suffix = "10" %}
    {% endif %}
{% endif %}
ENV NV_LIBCUBLAS_PACKAGE_NAME libcublas{{ suffix }}
ENV NV_LIBCUBLAS_VERSION {{ cuda.x86_64.components.libcublas.version }}
ENV NV_LIBCUBLAS_PACKAGE ${NV_LIBCUBLAS_PACKAGE_NAME}-${NV_LIBCUBLAS_VERSION}

{# NCCL comes either as a downloadable tarball (manifest entry has a "source"
   key) or as an RPM. Each flavour raises its own flag for the final stage. #}
{% if "libnccl2" in cuda.x86_64.components %}
{% if "source" in cuda.x86_64.components.libnccl2 %}
    {# Without this flag the download step is never rendered for x86_64 and the
       NV_LIBNCCL_DL_* variables below would go unused. #}
    {% set nccl_dl_package = 1 %}
ENV NV_LIBNCCL_DL_BASENAME nccl_{{ cuda.x86_64.components.libnccl2.version }}+cuda{{ cuda.version.major }}.{{ cuda.version.minor }}_x86_64.txz
ENV NV_LIBNCCL_DL_SHA256SUM {{ cuda.x86_64.components.libnccl2.sha256sum }}
ENV NV_LIBNCCL_DL_SOURCE {{ cuda.x86_64.components.libnccl2.source }}
{% else %}
    {% set nccl_package = 1 %}
ENV NV_LIBNCCL_PACKAGE_NAME libnccl
ENV NV_LIBNCCL_PACKAGE_VERSION {{ cuda.x86_64.components.libnccl2.version }}
{# NCCL_VERSION is the package version minus its last two characters
   (presumably the "-1" release suffix; confirm against the manifest). #}
ENV NCCL_VERSION {{ cuda.x86_64.components.libnccl2.version[:-2] }}
ENV NV_LIBNCCL_PACKAGE ${NV_LIBNCCL_PACKAGE_NAME}-${NV_LIBNCCL_PACKAGE_VERSION}+cuda{{ cuda.version.major }}.{{ cuda.version.minor }}
{% endif %}
{% else %}
    {% set nccl_package = 0 %}
{% endif -%}

{% endif %}
{% if "ppc64le" in cuda %}
# ppc64le stage; the final stage selects it through base-${TARGETARCH}.
FROM base AS base-ppc64le

# Component versions from the ppc64le section of the manifest; consumed by the
# package install in the final stage.
ENV NV_CUDA_LIB_VERSION {{ cuda.ppc64le.components.libraries.version }}
ENV NV_NVTX_VERSION {{ cuda.ppc64le.components.nvtx.version }}
ENV NV_LIBNPP_VERSION {{ cuda.ppc64le.components.libnpp.version }}

{# CUDA 10.1 and 10.2 use the package name libcublas10. For other versions
   suffix stays unset, which renders as an empty string with Jinja's default
   (non-strict) undefined handling. #}
{% if "libcublas" in cuda.ppc64le.components %}
    {% if cuda.version.major_minor in ["10.1", "10.2"] %}
        {% set suffix = "10" %}
    {% endif %}
{% endif %}
ENV NV_LIBCUBLAS_PACKAGE_NAME libcublas{{ suffix }}
ENV NV_LIBCUBLAS_VERSION {{ cuda.ppc64le.components.libcublas.version }}
ENV NV_LIBCUBLAS_PACKAGE ${NV_LIBCUBLAS_PACKAGE_NAME}-${NV_LIBCUBLAS_VERSION}

{# NCCL comes either as a downloadable tarball (manifest entry has a "source"
   key) or as an RPM. Each flavour raises its own flag for the final stage. #}
{% if "libnccl2" in cuda.ppc64le.components %}
{% if "source" in cuda.ppc64le.components.libnccl2 %}
    {% set nccl_dl_package = 1 %}
ENV NV_LIBNCCL_DL_BASENAME nccl_{{ cuda.ppc64le.components.libnccl2.version }}+cuda{{ cuda.version.major }}.{{ cuda.version.minor }}_ppc64le.txz
ENV NV_LIBNCCL_DL_SHA256SUM {{ cuda.ppc64le.components.libnccl2.sha256sum }}
ENV NV_LIBNCCL_DL_SOURCE {{ cuda.ppc64le.components.libnccl2.source }}
{% else %}
    {% set nccl_package = 1 %}
ENV NV_LIBNCCL_PACKAGE_NAME libnccl
ENV NV_LIBNCCL_PACKAGE_VERSION {{ cuda.ppc64le.components.libnccl2.version }}
{# NCCL_VERSION is the package version minus its last two characters
   (presumably the "-1" release suffix; confirm against the manifest). #}
ENV NCCL_VERSION {{ cuda.ppc64le.components.libnccl2.version[:-2] }}
ENV NV_LIBNCCL_PACKAGE ${NV_LIBNCCL_PACKAGE_NAME}-${NV_LIBNCCL_PACKAGE_VERSION}+cuda{{ cuda.version.major }}.{{ cuda.version.minor }}
{% endif %}
{% endif %}
{% endif %}
{# nccl_package is shared by every architecture section. Only give it a value
   when nothing has set it yet: resetting it to 0 here because the manifest has
   no ppc64le section would throw away the flag raised for another
   architecture and drop the NCCL RPM from the final install list. #}
{% set nccl_package = nccl_package | default(0) %}

# Final image: continue from whichever architecture stage matches the platform
# being built.
FROM base-${TARGETARCH}

# Bring the automatic platform argument into this stage's scope for the
# instructions that follow FROM.
ARG TARGETARCH

LABEL maintainer="NVIDIA CORPORATION <sw-cuda-installer@nvidia.com>"

{# Install the CUDA runtime libraries in a single layer and clean the yum cache
   in that same layer. Version pins come from the NV_* variables exported by the
   selected architecture stage. The CUDA 9.0 cublas pin is written in yum's
   name-version-release form; "name=version" is apt syntax, which yum does not
   recognise as a package spec. #}
{% if cuda.version.major_minor in ["10.0", "10.1"] %}
# setopt flag prevents yum from auto upgrading. See https://gitlab.com/nvidia/container-images/cuda/-/issues/88
RUN yum install --setopt=obsoletes=0 -y \
{% else %}
RUN yum install -y \
{% endif %}
    cuda-libraries-{{ cuda.version.major }}-{{ cuda.version.minor }}-${NV_CUDA_LIB_VERSION} \
    {% if cuda.version.major_minor == "9.0" %}
    cuda-cublas-{{ cuda.version.major }}-{{ cuda.version.minor }}-{{ cuda.version.release_label }}.4-1 \
    {% elif cuda.version.major_minor == "10.0" %}
    cuda-cublas-$CUDA_PKG_VERSION \
    {% endif %}
    {% if (cuda.version.major_minor == "9.2") or cuda.version.major | int >= 10 %}
    cuda-nvtx-{{ cuda.version.major }}-{{ cuda.version.minor }}-${NV_NVTX_VERSION} \
    {% endif %}
    cuda-npp-{{ cuda.version.major }}-{{ cuda.version.minor }}-${NV_LIBNPP_VERSION} \
    ${NV_LIBCUBLAS_PACKAGE} \
    {% if nccl_package %}
    ${NV_LIBNCCL_PACKAGE} \
    {% endif %}
    && yum clean all \
    && rm -rf /var/cache/yum/*

{% if cuda.version.major_minor in ["10.0", "10.1"] and cuda.os.distro != "ubi" %}
# Lock the cuBLAS package at its installed version. The yum cache is removed in
# the same layer so the plugin install does not leave metadata in the image.
RUN yum install -y yum-plugin-versionlock \
    && yum versionlock ${NV_LIBCUBLAS_PACKAGE_NAME} \
    && yum clean all \
    && rm -rf /var/cache/yum/*

{% endif -%}

{% if nccl_dl_package %}
# Install NCCL from the tarball named by the NV_LIBNCCL_DL_* variables:
# download it, verify its SHA-256, unpack the shared libraries into the CUDA
# lib64 directory and the pkg-config files into /usr/lib64/pkgconfig, then
# remove the archive and the yum cache in the same layer.
# unxz turns <name>.txz into <name>.tar; the tar file name is derived with the
# POSIX suffix-removal expansion so it does not depend on bash-only substring
# syntax.
RUN yum install -y xz && NCCL_DOWNLOAD_SUM=${NV_LIBNCCL_DL_SHA256SUM} \
    && NCCL_TARBALL="${NV_LIBNCCL_DL_BASENAME%.txz}.tar" \
    && curl -fsSL "${NV_LIBNCCL_DL_SOURCE}" -O \
    && echo "$NCCL_DOWNLOAD_SUM  ${NV_LIBNCCL_DL_BASENAME}" | sha256sum -c - \
    && unxz "${NV_LIBNCCL_DL_BASENAME}" \
    && tar --no-same-owner --keep-old-files -xvf "${NCCL_TARBALL}" -C /usr/local/cuda/lib64/ --strip-components=2 --wildcards '*/lib/libnccl.so.*' \
    && tar --no-same-owner --keep-old-files -xvf "${NCCL_TARBALL}" -C /usr/lib64/pkgconfig/ --strip-components=3 --wildcards '*/lib/pkgconfig/*' \
    && rm -f "${NCCL_TARBALL}" \
    && ldconfig \
    && yum clean all \
    && rm -rf /var/cache/yum/*
{% endif -%}
